"""
Replication Script for "Currency Development Through Liquidity Provision"
--------------------------------------------------
This script processes Chinese government USD bond yields and US Treasury yields, performs interpolation, and plots yield spreads for replication of "Currency Development Through Liquidity Provision" by Antonio Coppola, Arvind Krishnamurthy, and Chenzi Xu (2025).

Data sources:
- cn_gov_usd_bond_mapping.xlsx (Chinese government USD bond information)
- cn_gov_usd_bond_yld_bbg.xlsx (Chinese government USD bond yield data from Bloomberg)
- us_treasury_benchmark_yld_bbg.xlsx (US Treasury yield data from Bloomberg)

Author: Zhongyu Yin
Date: April 2025
"""

# ============================
# Imports and Settings
# ============================
import os
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from functools import reduce
from matplotlib.dates import DateFormatter

# Plotting settings
# Serif text and a matching serif math font for every figure drawn by this script.
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
})

# ============================
# Paths
# ============================
# Directory layout: <PROJECT_PATH>/data/raw is the folder passed to the loaders below.
PROJECT_PATH = 'ckx_replication'
DATA_PATH = os.path.join(PROJECT_PATH, 'data')
RAW_PATH = os.path.join(PROJECT_PATH, 'data', 'raw')

# Workbook file names, resolved relative to RAW_PATH when loading.
CHN_GOV_YLD_NAME = 'cn_gov_usd_bond_yld_bbg.xlsx'
TREASURY_YLD_NAME = 'us_treasury_benchmark_yld_bbg.xlsx'

# ============================
# Data Loading Functions
# ============================
def load_bond_mapping_and_yield(raw_path):
    """
    Read the bond mapping table and the bond yield panel from Excel.

    Args:
        raw_path: Directory that contains both workbooks.

    Returns:
        mapping (DataFrame): The 'Mapping' sheet of cn_gov_usd_bond_mapping.xlsx.
        yield_data (DataFrame): The 'Data' sheet of the workbook named by
            CHN_GOV_YLD_NAME, sorted ascending by its 'Date' column.
    """
    mapping_file = os.path.join(raw_path, 'cn_gov_usd_bond_mapping.xlsx')
    mapping = pd.read_excel(mapping_file, sheet_name='Mapping')

    yield_file = os.path.join(raw_path, CHN_GOV_YLD_NAME)
    raw_yields = pd.read_excel(yield_file, sheet_name='Data')

    return mapping, raw_yields.sort_values(by='Date')

# ============================
# Data Cleaning Functions for Chinese Government USD Bond
# ============================
def clean_mapping(mapping):
    """
    Deduplicate the bond mapping so that near-identical maturities keep one bond.

    Rows are visited in their existing order. A bond whose maturity lies within
    20 days of the currently kept bond's maturity is treated as a duplicate of
    it, and of the two only the bond with the earlier issue date is kept. Any
    other bond is kept and becomes the new reference for the rows that follow.

    Each row is compared only with the most recently kept bond, so duplicates
    are detected only when they are adjacent in row order.
    NOTE(review): presumably the sheet is ordered by maturity -- confirm.

    Args:
        mapping (DataFrame): Must have 'Issue Date' and 'Maturity' columns
            holding datetime-like values.

    Returns:
        DataFrame: The retained rows of `mapping`, in their original order.
    """
    window_days = 20

    keep_flags = []
    kept_position = None   # index into keep_flags of the bond currently kept
    kept_issue_date = None
    kept_maturity = None

    for position, (_, row) in enumerate(mapping.iterrows()):
        issue_date = row['Issue Date']
        maturity = row['Maturity']

        is_duplicate = (
            kept_maturity is not None
            and abs((maturity - kept_maturity).days) < window_days
        )

        if is_duplicate and not issue_date < kept_issue_date:
            # The kept bond is at least as old: discard this one.
            keep_flags.append(False)
            continue

        if is_duplicate:
            # This bond is older than the kept one, so it replaces it. The kept
            # bond is not necessarily the previous row (rejected duplicates may
            # sit in between), hence the explicit position rather than [-2].
            keep_flags[kept_position] = False

        keep_flags.append(True)
        kept_position = position
        kept_issue_date = issue_date
        kept_maturity = maturity

    # Positional selection keeps the columns intact even when nothing is kept.
    kept_rows = [position for position, keep in enumerate(keep_flags) if keep]
    return mapping.iloc[kept_rows]

# ============================
# Yield Processing Functions
# ============================
def process_yields(mapping, yield_data):
    """
    Construct a DataFrame of USD bond yields with one column per bond maturity.

    For every row of `mapping`, the yield series in the column named by the
    row's 'ISIN' is taken from `yield_data`, observations within 10 days of
    the bond's maturity are discarded, and the series is outer-merged on
    'Date' under a column named after the maturity ('YYYY-MM-DD').

    Args:
        mapping (DataFrame): Needs 'ISIN' and 'Maturity' columns; 'Maturity'
            must support `strftime` and subtraction from the dates.
        yield_data (DataFrame): Needs a 'Date' column plus one column per ISIN
            listed in `mapping`.

    Returns:
        tuple:
            DataFrame with 'Date', one column per bond, and 'Average' (row mean
            over the bond columns, skipping missing values). Rows in which
            every bond is missing are dropped.
            list of the bond column names (maturity strings) in column order.
        An empty `mapping` yields an empty frame with 'Date' and 'Average'
        columns and an empty list.

    Raises:
        KeyError: If an ISIN from `mapping` is not a column of `yield_data`.
    """
    exclusion_days = 10
    dates = pd.to_datetime(yield_data['Date'])

    # None (not `.empty`) marks "no bond merged yet": a first bond with zero
    # usable rows must still contribute its column instead of being replaced.
    merged = None
    for _, row in mapping.iterrows():
        isin_usd = row['ISIN']
        maturity = row['Maturity']
        maturity_str = maturity.strftime('%Y-%m-%d')

        # Vectorised form of abs((date - maturity).days) > exclusion_days.
        far_from_maturity = ((dates - maturity).dt.days.abs() > exclusion_days).to_numpy()
        bond_yields = yield_data.loc[far_from_maturity, ['Date', isin_usd]]

        if merged is None:
            merged = bond_yields.rename(columns={isin_usd: maturity_str})
        else:
            merged = pd.merge(merged, bond_yields, on='Date', how='outer') \
                       .rename(columns={isin_usd: maturity_str})

    if merged is None:
        return pd.DataFrame(columns=['Date', 'Average']), []

    bonds = [col for col in merged.columns if col != 'Date']
    merged.dropna(how='all', subset=bonds, inplace=True)
    merged['Average'] = merged.drop(columns=['Date']).mean(axis=1)

    return merged, bonds

# ============================
# Interpolation Functions
# ============================
def get_weight(date, tenor, bonds):
    """
    Pick the bonds bracketing the target maturity and compute blend weights.

    The target maturity is `date` plus `tenor * 365` days. Among `bonds`, the
    closest maturity strictly after the target and the closest strictly before
    it are selected (distance measured in whole days). A bond maturing exactly
    on the target is in neither group.

    Weights:
        * Both neighbours within 365 days: linear weights, each neighbour
          weighted by the other's share of the total distance.
        * Only one neighbour within 365 days (the other further away or
          absent): that neighbour gets weight 1 if it is within 180 days.
        * Otherwise both weights are None.

    Args:
        date: Anything accepted by `pd.Timestamp`.
        tenor: Tenor in years (multiplied by 365 to obtain days).
        bonds: Iterable of maturity labels accepted by `pd.Timestamp`.

    Returns:
        tuple: (match_after, match_before, weight_after, weight_before). The
        matches are elements of `bonds` or None; the weights are numbers or None.

    Raises:
        ValueError: If a bond label cannot be parsed as a timestamp (this used
            to be swallowed and reported as "no matching bond").
    """
    max_gap_days = 365          # a neighbour further away than this is not blended
    single_bond_gap_days = 180  # max distance for using one bond on its own

    date_tenor = pd.Timestamp(date) + pd.Timedelta(days=tenor*365)

    # (label, signed whole-day offset) split by side; the side is decided on
    # the timestamps themselves, the ranking on the whole-day offset.
    parsed = [(bond, pd.Timestamp(bond)) for bond in bonds]
    later = [(bond, (stamp - date_tenor).days) for bond, stamp in parsed if stamp > date_tenor]
    earlier = [(bond, (stamp - date_tenor).days) for bond, stamp in parsed if stamp < date_tenor]

    match_after, delta_after = min(later, key=lambda item: item[1], default=(None, None))
    match_before, delta_before = max(earlier, key=lambda item: item[1], default=(None, None))

    weight_after, weight_before = None, None
    if match_after is not None and match_before is not None:
        gap_after, gap_before = abs(delta_after), abs(delta_before)
        if gap_after < max_gap_days and gap_before < max_gap_days:
            total_gap = gap_after + gap_before
            weight_after = gap_before / total_gap
            weight_before = gap_after / total_gap
        elif gap_after >= max_gap_days and gap_before < max_gap_days:
            if gap_before < single_bond_gap_days:
                weight_after, weight_before = 0, 1
        elif gap_after < max_gap_days and gap_before >= max_gap_days:
            if gap_after < single_bond_gap_days:
                weight_after, weight_before = 1, 0
    elif match_after is None and delta_before is not None and abs(delta_before) < single_bond_gap_days:
        weight_after, weight_before = 0, 1
    elif match_before is None and delta_after is not None and abs(delta_after) < single_bond_gap_days:
        weight_after, weight_before = 1, 0

    return match_after, match_before, weight_after, weight_before

def _blend_yields(row, match_after, match_before, weight_after, weight_before):
    """Return the interpolated yield for one date, or NaN when unavailable."""
    if row is None or weight_after is None or weight_before is None:
        return np.nan

    # A missing neighbour (None) or a bond without a column reads as NaN.
    yield_after = row.get(match_after, np.nan)
    yield_before = row.get(match_before, np.nan)

    if weight_after == 1:
        return yield_after
    if weight_before == 1:
        return yield_before

    blended = (yield_after * weight_after) + (yield_before * weight_before)
    if not pd.isna(blended):
        return blended

    # One neighbour has no quote: fall back to the nearer bond (weight >= 0.5).
    if not pd.isna(yield_after) and weight_after >= 0.5:
        return yield_after
    if weight_before >= 0.5:
        return yield_before
    return np.nan


def linear_interpolation(data_input, date_list, tenor_list, bonds):
    """
    Interpolate yields for given tenors and dates using linear interpolation.

    For each (date, tenor) pair, `get_weight` selects the bracketing bonds and
    their weights; the yields of those bonds on that date are then blended.
    When only one of the two yields is available, the nearer bond's yield is
    used on its own if its weight is at least 0.5.

    Args:
        data_input (DataFrame): Has a 'Date' column and one yield column per
            entry of `bonds`. If a date occurs more than once, its first row
            is used.
        date_list: Iterable of dates to interpolate for.
        tenor_list: Iterable of tenors (years).
        bonds: Maturity labels, matching column names of `data_input`.

    Returns:
        DataFrame: 'Date' plus one 'yield_usd_cn_<tenor>' column per tenor.
        Missing values are NaN, so a column stays numeric even when no value
        could be interpolated for it.
    """
    # Build the per-date lookup once rather than filtering the frame for every
    # (date, tenor) pair.
    yields_by_date = (data_input.drop_duplicates(subset='Date', keep='first')
                                .set_index('Date')
                                .to_dict('index'))

    output = pd.DataFrame(date_list, columns=['Date'])
    for tenor in tenor_list:
        output[f'yield_usd_cn_{tenor}'] = [
            _blend_yields(yields_by_date.get(pd.Timestamp(date)),
                          *get_weight(date, tenor, bonds))
            for date in date_list
        ]
    return output

# ============================
# Treasury Data Processing
# ============================
def process_ticker_data(df, ticker):
    """
    Pull one Bloomberg ticker's columns out of a two-level-header frame.

    Keeps the columns whose first header level is `ticker` or "DATES", flattens
    the header to its second level, drops rows that are entirely missing,
    renumbers the index, and strips the 'dropna(' prefix and the
    '(dates=range(-20y,0d)))' suffix from the column names.
    """
    wanted_labels = (ticker, "DATES")
    selected = [col for col in df.columns if col[0] in wanted_labels]

    subset = df[selected].copy()
    subset.columns = subset.columns.get_level_values(1)
    subset = subset.dropna(how='all')
    subset = subset.reset_index(drop=True)

    cleaned_names = []
    for name in subset.columns:
        name = name.replace('dropna(', '')
        name = name.replace('(dates=range(-20y,0d)))', '')
        cleaned_names.append(name)
    subset.columns = cleaned_names
    return subset

def load_and_merge_treasury_data(raw_path, treasury_file_name):
    """
    Read the Bloomberg Treasury workbook and return one yield column per tenor.

    The workbook is read with a two-row header. Every first-level label other
    than "DATES" is treated as a ticker: its columns are extracted with
    `process_ticker_data`, prefixed with the ticker name, and all tickers are
    outer-merged on 'Date'. Only 'Date' and the columns ending in 'px_last'
    are returned, with the known benchmark tickers renamed to 'Treasury_<n>Y'.

    Args:
        raw_path: Directory containing the workbook.
        treasury_file_name: File name of the workbook inside `raw_path`.

    Returns:
        DataFrame: Sorted by 'Date' with a fresh integer index.
    """
    workbook = pd.read_excel(os.path.join(raw_path, treasury_file_name), header=[0,1])

    # Bloomberg field names -> lower-case suffixes, applied in this order.
    field_aliases = (
        ('PX_LAST', 'px_last'),
        ('LAST_PRICE', 'px_last'),
        ('PX_ASK', 'px_ask'),
        ('PX_BID', 'px_bid'),
        ('PX_MID', 'px_mid'),
    )

    per_ticker = []
    for ticker in pd.unique(workbook.columns.get_level_values(0)):
        if ticker == "DATES":
            continue
        frame = process_ticker_data(workbook, ticker)
        renames = {}
        for col in frame.columns:
            if col == 'Date':
                continue
            suffix = col
            for bloomberg_name, alias in field_aliases:
                suffix = suffix.replace(bloomberg_name, alias)
            renames[col] = f"{ticker}_{suffix}"
        per_ticker.append(frame.rename(columns=renames))

    merged = reduce(lambda left, right: pd.merge(left, right, on='Date', how='outer'), per_ticker)
    merged = merged.sort_values('Date').reset_index(drop=True)

    # Keep only px_last fields and rename for treasury yields
    kept_columns = [c for c in merged.columns if c == 'Date' or c.endswith('px_last')]
    treasury_names = {
        'GB12 Govt_px_last':'Treasury_1Y',
        'GT2 Govt_px_last':'Treasury_2Y',
        'GT3 Govt_px_last':'Treasury_3Y',
        'GT5 Govt_px_last':'Treasury_5Y',
        'GT7 Govt_px_last':'Treasury_7Y',
        'GT10 Govt_px_last':'Treasury_10Y'
    }
    return merged[kept_columns].rename(columns=treasury_names)

# ============================
# Plotting Function
# ============================
def plot_yield_spreads(output, data_path, start_date, end_date):
    """
    Plot and save USD-CN vs. US Treasury yield spreads for selected tenors.

    For each of the 2, 5, 7 and 10 year tenors, the spread
    'yield_usd_cn_<tenor>' minus 'Treasury_<tenor>Y' is drawn against 'Date'
    for rows with start_date <= Date <= end_date. Tenors whose columns are
    missing from `output` are skipped.

    Args:
        output (DataFrame): Needs a 'Date' column plus the yield columns above.
        data_path: Base directory; the figure is written to
            <data_path>/figure/chn_gov_usd_treasury_spreads.pdf (the folder is
            created if it does not exist).
        start_date, end_date: Inclusive bounds compared against 'Date'.

    Side effects:
        Sets the global matplotlib font size to 14, saves the PDF and shows
        the figure.
    """
    # Imported locally because only this function needs them.
    from matplotlib.dates import num2date
    from matplotlib.ticker import FuncFormatter

    plt.figure(figsize=(12, 8.2))
    matplotlib.rcParams.update({'font.size': 14})
    out_plt_df = output[(output['Date'] >= start_date) & (output['Date'] <= end_date)]
    for tenor in ['2','5','7','10']:
        if f'yield_usd_cn_{tenor}' in out_plt_df.columns and f'Treasury_{tenor}Y' in out_plt_df.columns:
            spread = out_plt_df[f'yield_usd_cn_{tenor}'] - out_plt_df[f'Treasury_{tenor}Y']
            plt.plot(out_plt_df['Date'], spread, label=f'{tenor} Years')
    ax = plt.gca()

    # Tick labels look like "2020m6". The strftime directive "%-m" (month
    # without zero padding) is a platform extension that is not accepted
    # everywhere, so the label is assembled from the date fields instead.
    def _year_month_label(value, _position):
        tick_date = num2date(value)
        return f"{tick_date.year}m{tick_date.month}"

    ax.xaxis.set_major_formatter(FuncFormatter(_year_month_label))
    plt.ylabel(r'Yield Spread, $y^{CHN,USD}_{i,t} - y^{Treasury}_{i,t}$ (%)')
    plt.axhline(0, color='black', linestyle='--')
    plt.legend(loc='lower left')
    sns.despine()

    # savefig does not create missing folders.
    figure_dir = os.path.join(data_path, "figure")
    os.makedirs(figure_dir, exist_ok=True)
    plt.savefig(os.path.join(figure_dir, "chn_gov_usd_treasury_spreads.pdf"), bbox_inches='tight', pad_inches=.3)
    plt.show()

# ============================
# Main Script
# ============================
if __name__ == "__main__":
    # 1. Chinese government USD bonds: load, deduplicate the mapping, and
    #    build the per-maturity yield panel.
    mapping, yield_chn_gov_usd = load_bond_mapping_and_yield(RAW_PATH)
    mapping = clean_mapping(mapping)
    yield_chn_gov_usd, bonds = process_yields(mapping, yield_chn_gov_usd)

    # 2. Interpolate to constant tenors on a daily calendar grid.
    tenor_list = [1,2,3,4,5,7,10]
    date_list = pd.date_range(start='2017-10-27', end='2024-12-31', freq='D')
    yield_chn_gov_usd_interpolated = linear_interpolation(
        yield_chn_gov_usd, date_list, tenor_list, bonds
    )

    # 3. US Treasury benchmark yields.
    treasury_benchmark = load_and_merge_treasury_data(RAW_PATH, TREASURY_YLD_NAME)

    # 4. Combine both sources on Date (keeping dates present in either one)
    #    and order chronologically.
    output = pd.merge(
        yield_chn_gov_usd_interpolated,
        treasury_benchmark,
        on='Date',
        how='outer',
    ).sort_values(by='Date')

    # 5. Plot and save the yield spreads.
    plot_yield_spreads(output, DATA_PATH, start_date='2020-06-01', end_date='2024-12-31')
